import logging
import os
from collections import defaultdict
from timeit import default_timer

import brewer2mpl
import wandb
from disentanglement_lib.evaluation.metrics import mig
from tqdm import tqdm

# from disvae import mig, sap, modularity
from disvae.utils.modelIO import save_metadata

# Qualitative colour palette (3 colours) used by plotting helpers.
bmap = brewer2mpl.get_map('Set1', 'qualitative', 3)
colors = bmap.mpl_colors

# Filenames (relative to the evaluator's save directory) for persisted results.
TEST_LOSSES_FILE = "test_losses.log"
METRICS_FILENAME = "metrics.log"
METRIC_HELPERS_FILE = "metric_helpers.pth"
# Latent dimensions whose empirical variance exceeds this are "active" units.
VAR_THRESHOLD = 0.15

import math
import gin
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
eps = 1e-8  # small numerical-stability constant

# Hyper-parameters for the disentanglement_lib metrics, bound through gin.
_METRIC_GIN_BINDINGS = [
    "mig.num_train=10000",
    "discretizer.discretizer_fn = @histogram_discretizer",
    "discretizer.num_bins = 20",

    "factor_vae_score.num_variance_estimate=10000",
    "factor_vae_score.num_train=10000",
    "factor_vae_score.num_eval=5000",
    "factor_vae_score.batch_size=64",
    "prune_dims.threshold = 0.05",

    "modularity_explicitness.num_train=10000",
    "modularity_explicitness.num_test=5000",

    "sap_score.num_train=10000",
    "sap_score.num_test=5000",
    "sap_score.continuous_factors=False",

    "dci.num_train=10000",
    "dci.num_test=5000",
]
# Second argument skips bindings whose configurable is not registered.
gin.parse_config(_METRIC_GIN_BINDINGS, True)

class Evaluator:
    """
    Class to handle evaluation of a model.

    Parameters
    ----------
    model: disvae.vae.VAE

    loss_f: disvae.models.BaseLoss
        Loss function.

    device: torch.device, optional
        Device on which to run the code.

    logger: logging.Logger, optional
        Logger.

    save_dir : str, optional
        Directory for saving logs.

    is_progress_bar: bool, optional
        Whether to use a progress bar for evaluation.
    """

    def __init__(self, model, loss_f,
                 device=torch.device("cpu"),
                 logger=logging.getLogger(__name__),
                 save_dir="results",
                 is_progress_bar=True):

        self.device = device
        self.loss_f = loss_f
        self.model = model.to(self.device)
        self.logger = logger
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.logger.info("Testing Device: {}".format(self.device))

    def __call__(self, data_loader, is_metrics=False, is_losses=True):
        """Compute test losses and/or disentanglement metrics.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        is_metrics: bool, optional
            Whether to compute and store the disentangling metrics.

        is_losses: bool, optional
            Whether to compute and store the test losses.

        Returns
        -------
        metric: dict or None
            Disentanglement metrics (None when `is_metrics` is False).

        losses: dict or None
            Averaged test losses (None when `is_losses` is False).
        """
        start = default_timer()
        is_still_training = self.model.training
        self.model.eval()

        metric, losses = None, None
        if is_metrics:
            self.logger.info('Computing metrics...')
            # BUGFIX: the computed metrics were previously assigned to a
            # local `metrics` and never to `metric`, so the method always
            # returned None for them.
            _, _, metric = self.compute_metrics(data_loader)
            # BUGFIX: message previously said "Losses" for the metrics.
            self.logger.info('Metrics: {}'.format(metric))
            save_metadata(metric, self.save_dir, filename=METRICS_FILENAME)

        if is_losses:
            self.logger.info('Computing losses...')
            losses = self.compute_losses(data_loader)
            self.logger.info('Losses: {}'.format(losses))
            save_metadata(losses, self.save_dir, filename=TEST_LOSSES_FILE)

        if is_still_training:
            # Restore the mode the model arrived in.
            self.model.train()

        self.logger.info('Finished evaluating after {:.1f} min.'.format((default_timer() - start) / 60))

        return metric, losses

    def compute_losses(self, dataloader):
        """Compute all test losses, averaged over the whole dataset.

        Parameters
        ----------
        dataloader: torch.utils.data.DataLoader

        Returns
        -------
        dict
            Mapping from loss name to its mean over all batches.
        """
        storer = defaultdict(list)
        for data, _ in tqdm(dataloader, leave=False, disable=not self.is_progress_bar):
            data = data.to(self.device)

            try:
                recon_batch, latent_dist, latent_sample = self.model(data)
                _ = self.loss_f(data, recon_batch, latent_dist, self.model.training,
                                storer, latent_sample=latent_sample)
            except ValueError:
                # for losses that use multiple optimizers (e.g. Factor),
                # which implement a different entry point
                _ = self.loss_f.call_optimize(data, self.model, None, storer)

        losses = {k: sum(v) / len(v) for k, v in storer.items()}
        if wandb.run:
            wandb.log(losses)
        return losses

    def compute_metrics(self, dataloader):
        """Compute all the disentanglement metrics.

        Parameters
        ----------
        dataloader: torch.utils.data.DataLoader

        Returns
        -------
        params_zCx: tuple of torch.Tensor
            Sufficient statistics of q(z|x) for every example.

        labels: torch.Tensor
            Ground-truth factors for every example.

        metric_helpers: dict
            Computed metric values (also saved to METRIC_HELPERS_FILE).
        """
        self.logger.info("Computing the empirical distribution q(z|x).")

        metric_helpers = dict()
        # Only MIG is currently enabled; other disentanglement_lib scorers
        # (sap, modularity, ...) can be appended to this list.
        for score in [mig.compute_mig]:
            metric = score(dataloader.dataset, self.representation_func,
                           np.random.RandomState())
            metric_helpers.update(metric)

        self.logger.info("Metric helpers: {}".format(metric_helpers))
        torch.save(metric_helpers, os.path.join(self.save_dir, METRIC_HELPERS_FILE))
        lat_sizes = dataloader.dataset.factors_num_values
        params_zCx, labels = self.compute(dataloader)
        if wandb.run:
            ndict = metric_helpers.copy()
            try:
                self._plot_latent_vs_ground(params_zCx, ndict, latnt_sizes=lat_sizes)
            except Exception:
                # Plotting stays best-effort (never abort metric computation),
                # but leave a trace instead of silently swallowing the error.
                self.logger.exception("Failed to plot latent-vs-ground figure.")
            wandb.log(ndict)
        else:
            self._plot_latent_vs_ground(params_zCx, None, latnt_sizes=lat_sizes)

        return params_zCx, labels, metric_helpers

    def _plot_latent_vs_ground(self, param, n_dict=None, z_inds=None,
                               latnt_sizes=(3, 6, 40, 32, 32)):
        """Plot each active latent's mean response against each factor.

        Parameters
        ----------
        param: tuple of torch.Tensor
            Sufficient statistics of q(z|x); param[0] holds the means with
            shape (len_dataset, latent_dim).

        n_dict: dict, optional
            When given, the figure is also stored under 'gt_vs_latent' as a
            wandb.Image for logging.

        z_inds: list of int, optional
            Latent indices to plot; defaults to the active units.

        latnt_sizes: sequence of int, optional
            Number of distinct values of each ground-truth factor.
        """
        import matplotlib.pyplot as plt
        K = param[0].shape[-1]
        # Reshape the means so each factor gets its own axis: (*factors, K).
        qz_means = param[0].view(*(list(latnt_sizes) + [K])).cpu().data
        var = torch.std(qz_means.contiguous().view(-1, K), dim=0).pow(2)
        active_units = torch.arange(0, K)[var > VAR_THRESHOLD].long().tolist()
        # BUGFIX: a debug leftover (`active_units = [1, 2, 3]`) previously
        # overwrote the computed list. Fall back to all latents when no unit
        # passes the threshold so subplots() is never asked for zero rows.
        if not active_units:
            active_units = list(range(K))
        if z_inds is None:
            z_inds = active_units
        self.logger.info('Active units: %s', ','.join(map(str, z_inds)))
        n_active = len(z_inds)
        self.logger.info('Number of active units: {}/{}'.format(n_active, K))
        num_factor = len(qz_means.shape) - 1

        fig, axes = plt.subplots(n_active, num_factor,
                                 figsize=(num_factor * 3, (n_active + 1) * 3),
                                 squeeze=False)  # default is (8,6)

        for j in range(num_factor):
            # Average over every factor axis except the j-th one.
            mean_latent = qz_means.mean(dim=[dim for dim in range(num_factor) if dim != j])
            for ax, i in zip(axes[:, j], z_inds):
                ax.plot(mean_latent[:, i].numpy())
                x0, x1 = ax.get_xlim()
                y0, y1 = ax.get_ylim()
                # Force square panels independently of the data ranges.
                ax.set_aspect(abs(x1 - x0) / abs(y1 - y0))
                if i == z_inds[-1]:
                    ax.set_xlabel(f'c_{j}')
                if j == 0:
                    ax.set_ylabel(f'z_{i}')

        fig.text(0.5, 0.03, 'Ground Truth', ha='center')
        fig.text(0.01, 0.5, 'Learned Latent Variables ', va='center', rotation='vertical')
        if n_dict:
            n_dict['gt_vs_latent'] = wandb.Image(fig)
        fig.savefig(os.path.join(self.save_dir, 'gt_vs_latent.svg'))
        # Close the specific figure to avoid leaking figure objects.
        plt.close(fig)

    def representation_func(self, x):
        """Encode a NHWC numpy batch and return the latent means as numpy.

        Used as the `representation_function` for disentanglement_lib scorers,
        which supply channels-last batches; the model expects channels-first.
        """
        x = torch.tensor(x.transpose([0, 3, 1, 2]), dtype=torch.float32,
                         device=self.device)
        with torch.no_grad():
            mean, logvar = self.model.encoder(x)
        return mean.cpu().numpy()

    def compute(self, dataloader):
        """Compute the empirical distribution of q(z|x).

        Parameters
        ----------
        dataloader: torch.utils.data.DataLoader
            Batch data iterator.

        Returns
        -------
        params_zCX: tuple of torch.Tensor
            Sufficient statistics of q(z|x) for each training example. E.g. for
            gaussian (mean, log_var) each of shape (len_dataset, latent_dim).

        labels: torch.Tensor
            Ground-truth factors for every example.
        """
        len_dataset = len(dataloader.dataset)
        latent_dim = self.model.latent_dim
        n_suff_stat = 2  # gaussian posterior: mean and log-variance

        q_zCx = torch.zeros(len_dataset, latent_dim, n_suff_stat, device=self.device)
        labels = []
        n = 0
        with torch.no_grad():
            for x, label in dataloader:
                batch_size = x.size(0)
                idcs = slice(n, n + batch_size)
                x = x.float()
                q_zCx[idcs, :, 0], q_zCx[idcs, :, 1] = self.model.encoder(x.to(self.device))
                # NOTE: the previous unused `z = self.model.reparameterize(...)`
                # sample was dropped — its result was never read.
                labels.append(label)
                n += batch_size

        params_zCX = q_zCx.unbind(-1)
        labels = torch.cat(labels)
        return params_zCX, labels


class Normal(nn.Module):
    """Diagonal Gaussian that samples via the reparameterization trick.

    Parameters may either come from the instance (`mu`, `logsigma`) or be
    passed explicitly as a tensor whose last dimension stacks (mu, logsigma).
    """

    def __init__(self, mu=0, sigma=1):
        super(Normal, self).__init__()
        # log(2*pi) normalization term of the Gaussian log-density.
        # NOTE: `torch.autograd.Variable` is deprecated and is a no-op
        # wrapper in modern PyTorch, so plain tensors are used instead;
        # behavior is identical.
        self.normalization = torch.Tensor([np.log(2 * np.pi)])

        self.mu = torch.Tensor([mu])
        self.logsigma = torch.Tensor([math.log(sigma)])

    def _check_inputs(self, size, mu_logsigma):
        """Resolve (mu, logsigma) from an explicit params tensor and/or a size.

        Exactly interprets the four combinations of `size`/`mu_logsigma`;
        raises ValueError when both are None.
        """
        if size is None and mu_logsigma is None:
            raise ValueError(
                'Either one of size or params should be provided.')
        elif size is not None and mu_logsigma is not None:
            # Broadcast the given params to the requested size.
            mu = mu_logsigma.select(-1, 0).expand(size)
            logsigma = mu_logsigma.select(-1, 1).expand(size)
            return mu, logsigma
        elif size is not None:
            # Fall back to the distribution's own parameters.
            mu = self.mu.expand(size)
            logsigma = self.logsigma.expand(size)
            return mu, logsigma
        elif mu_logsigma is not None:
            mu = mu_logsigma.select(-1, 0)
            logsigma = mu_logsigma.select(-1, 1)
            return mu, logsigma
        else:
            # BUGFIX: removed stray ')' from the format string.
            raise ValueError(
                'Given invalid inputs: size={}, mu_logsigma={}'.format(
                    size, mu_logsigma))

    def sample(self, size=None, params=None):
        """Draw a reparameterized sample: mu + sigma * eps, eps ~ N(0, I)."""
        mu, logsigma = self._check_inputs(size, params)
        std_z = torch.randn(mu.size()).type_as(mu.data)
        sample = std_z * torch.exp(logsigma) + mu
        return sample

    def log_density(self, sample, params=None):
        """Element-wise log N(sample; mu, sigma^2)."""
        if params is not None:
            mu, logsigma = self._check_inputs(None, params)
        else:
            mu, logsigma = self._check_inputs(sample.size(), None)
            mu = mu.type_as(sample)
            logsigma = logsigma.type_as(sample)

        c = self.normalization.type_as(sample.data)
        inv_sigma = torch.exp(-logsigma)
        tmp = (sample - mu) * inv_sigma
        # -0.5 * ((x - mu)^2 / sigma^2 + log(sigma^2) + log(2*pi))
        return -0.5 * (tmp * tmp + 2 * logsigma + c)

    def NLL(self, params, sample_params=None):
        """Analytically computes
            E_N(mu_2,sigma_2^2) [ - log N(mu_1, sigma_1^2) ]
        If mu_2, and sigma_2^2 are not provided, defaults to entropy.
        """
        mu, logsigma = self._check_inputs(None, params)
        if sample_params is not None:
            sample_mu, sample_logsigma = self._check_inputs(None, sample_params)
        else:
            sample_mu, sample_logsigma = mu, logsigma

        c = self.normalization.type_as(sample_mu.data)
        # 0.5 * ((mu_2 - mu_1)^2 / sigma_1^2 + sigma_2^2 / sigma_1^2
        #        + log(sigma_1^2) + log(2*pi))
        nll = logsigma.mul(-2).exp() * (sample_mu - mu).pow(2) \
              + torch.exp(sample_logsigma.mul(2) - logsigma.mul(2)) + 2 * logsigma + c
        return nll.mul(0.5)

    def kld(self, params):
        """Computes KL(q||p) where q is the given distribution and p
        is the standard Normal distribution.
        """
        mu, logsigma = self._check_inputs(None, params)
        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mean^2 - sigma^2)
        kld = logsigma.mul(2).add(1) - mu.pow(2) - logsigma.exp().pow(2)
        kld.mul_(-0.5)
        return kld

    def get_params(self):
        """Return the distribution's own parameters as a flat tensor."""
        return torch.cat([self.mu, self.logsigma])

    @property
    def nparams(self):
        # Two sufficient statistics: mu and logsigma.
        return 2

    @property
    def ndim(self):
        return 1

    @property
    def is_reparameterizable(self):
        return True

    def __repr__(self):
        tmpstr = self.__class__.__name__ + ' ({:.3f}, {:.3f})'.format(
            self.mu.item(), self.logsigma.exp().item())
        return tmpstr